
@sarthak-amd (Collaborator) commented Oct 16, 2025

Description

The fused cross-entropy Triton kernel currently has two bugs (a reference sketch of both follows the description):

  1. If `ignore_idx` is not None, the loss should be computed only over the valid tokens, not over all tokens (new fix).
  2. Gradient scaling when `reduce_loss=False` (already fixed upstream):
    • If `reduce_loss=False`, we should return the per-token losses without reducing them; otherwise the gradients are shrunk by 1/N (vanishing gradients), which shows up as a wrong (higher) training loss.
    • When `reduce_loss=False`, `grad_output` is a tensor, not a scalar, so the backward pass must load one value per row instead of a single scalar.

This fix was validated with Llama 3.1 8B pre-training.
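As a reference for the intended semantics, here is a minimal PyTorch sketch (illustration only, not the Triton kernel itself), assuming flattened logits of shape [N, V], targets of shape [N], and -100 as the ignore index:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; the real kernel operates on flattened [tokens, vocab] logits.
N, V, IGNORE_IDX = 8, 32, -100
logits = torch.randn(N, V, requires_grad=True)
targets = torch.randint(0, V, (N,))
targets[::3] = IGNORE_IDX  # mark a few tokens as ignored

# Bug 1: with ignore_idx set, the reduced loss must average over valid tokens only.
per_token = F.cross_entropy(logits, targets, ignore_index=IGNORE_IDX, reduction="none")
num_valid = (targets != IGNORE_IDX).sum()
loss_correct = per_token.sum() / num_valid   # divide by the number of valid tokens
loss_buggy = per_token.sum() / N             # dividing by all tokens understates the loss

# Bug 2: with reduce_loss=False the per-token losses are returned unreduced, and in
# backward grad_output carries one value per row, so each row's gradient is scaled
# by its own upstream value rather than by a single scalar (or shrunk by 1/N).
grad_output = torch.rand(N)
per_token.backward(gradient=grad_output)
```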

Type of change

  • Bug fix (non-breaking change which fixes an issue)

@sarthak-amd sarthak-amd marked this pull request as ready for review October 16, 2025 12:31
@sarthak-amd sarthak-amd changed the title Loss Scaling and Vanishing Grads Fused Cross Entropy Triton - Loss Scaling and Vanishing Grads Bugfix Oct 16, 2025
@wenchenvincent (Collaborator)

@sarthak-amd Could you post the PR for the upstream fix?

@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
Review comment (Collaborator):

There is no real change in this file. Let's keep this file intact and then we don't need to add the AMD copyright statement.

@sarthak-amd (Collaborator, Author)

> @sarthak-amd Could you post the PR for the upstream fix?

NVIDIA/TransformerEngine@e9a5fa4 @wenchenvincent

@wenchenvincent (Collaborator)

> @sarthak-amd Could you post the PR for the upstream fix?
>
> NVIDIA/TransformerEngine@e9a5fa4 @wenchenvincent

Another fix came from the upstream PR NVIDIA/TransformerEngine#1879. Is the test change from that PR also reflected here?

@wenchenvincent (Collaborator)

For the ignore_idx fix, is there a test covering it (i.e., a test that would fail without the fix)?
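A hedged sketch of what such a test could check; `fused_cross_entropy` below is a hypothetical placeholder for the kernel's Python entry point, not the actual API:

```python
import torch
import torch.nn.functional as F

def test_ignore_idx_masks_reduced_loss():
    torch.manual_seed(0)
    N, V, IGNORE_IDX = 16, 128, -100
    logits = torch.randn(N, V)
    targets = torch.randint(0, V, (N,))
    targets[:4] = IGNORE_IDX  # some ignored tokens

    # Reference: PyTorch averages over valid tokens only when ignore_index is set.
    ref = F.cross_entropy(logits, targets, ignore_index=IGNORE_IDX, reduction="mean")

    # fused = fused_cross_entropy(logits, targets, ignore_idx=IGNORE_IDX, reduce_loss=True)
    # torch.testing.assert_close(fused, ref)

    # Without the fix, the kernel effectively divides by all N tokens, which differs
    # from the reference whenever any token is ignored, so the comparison above fails.
    buggy = F.cross_entropy(logits, targets, ignore_index=IGNORE_IDX, reduction="sum") / N
    assert not torch.isclose(ref, buggy)
```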

@wenchenvincent (Collaborator) left a review comment

@sarthak-amd Could you refactor the PR into 3 commits:

  • 2 commits cherry-picking the fixes from the upstream PRs.
  • 1 commit for the ignore_idx fix, with a test to cover it.

That way the PR would be very clear and easy to understand.

@wenchenvincent (Collaborator)

@sarthak-amd Could you address the comments? Also, please rebase onto the latest dev so that the hotfixes for the sgpu tests can pass.
